# -*- encoding: utf-8 -*-
'''
@File    :   cuda2d_sampling.py
@Time    :   2021/10/09 00:46:04
@Author  :   Ming Ding 
@Contact :   dm18@mails.tsinghua.edu.cn
'''

# here put the import lib
import os
import sys
import math
import random
import torch

import torch
import torch.nn.functional as F
import numpy as  np

def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits[indices_to_remove] = filter_value     
    return logits

class IterativeEntfilterStrategy:
    def __init__(self, invalid_slices=[], temperature=1., topk=6, temperature2=0.9):
        self.invalid_slices = invalid_slices
        self.temperature = temperature
        self.topk = topk        
        self.cluster_labels = torch.tensor(np.load('cluster_label.npy'), device='cuda', dtype=torch.long)
        self.temperature2 = temperature2


    def forward(self, logits_, tokens, temperature=None):
        # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
        if temperature is None:
            temperature = self.temperature 
            
        logits = logits_.float() / temperature
        for invalid_slice in self.invalid_slices:
            logits[..., invalid_slice] = -float('Inf')
        logits = logits.view(-1, logits.shape[-1])
        
        rprobs = F.softmax(logits.float(), dim=-1)
        c = self.cluster_labels.expand(*rprobs.shape)
        cprobs = torch.zeros(logits.shape[0], 500, device=logits.device).scatter_add_(1, c, rprobs)
    
        best_scores, best_clusters = cprobs.topk(self.topk)
        bz = logits.shape[0]
        best_scores = best_scores / best_scores.sum(dim=-1, keepdim=True)
        sampled_ids = torch.multinomial(best_scores, num_samples=1)
        selected_clusters = torch.gather(best_clusters, dim=1, index=sampled_ids)
        selected_mask = (self.cluster_labels.unsqueeze(0).expand(bz, -1) != selected_clusters) # cluster_labels [1, 20000] \in [0,500)
        logits[selected_mask] = -65504
        # for i in range(bz):
        #     selected_cluster = best_clusters[i][torch.multinomial(best_scores[i] / best_scores[i].sum(), num_samples=1)]
        #     logits[i, self.cluster_labels != selected_cluster] = -65504
            
        # logits = top_k_logits(logits, self.topk, self.top_p)
        probs = F.softmax(logits.float() / self.temperature2, dim=-1)  # float is essetial, due to a bug in Pytorch
        pred = torch.multinomial(probs, num_samples=1).view(*logits_.shape[:2])
        
        assert tokens.shape[1] == pred.shape[1] + 1
        tokens = torch.cat((tokens[:, :1], pred), dim=1)
        return tokens

# class IterativeEntfilterStrategy:
#     def __init__(self, invalid_slices=[], temperature=1., topk=40):
#         self.invalid_slices = invalid_slices
#         self.temperature = temperature
#         self.topk = topk

#     def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
#         # In interative strategy, logits are of shape [batch_size, seq_length, hidden_size]
#         if temperature is None:
#             temperature = self.temperature 
        
#         logits = logits.float() / temperature
#         for invalid_slice in self.invalid_slices:
#             logits[..., invalid_slice] = -float('Inf')
        
#         top_k_logits_(logits, self.topk)
#         probs = F.softmax(logits, dim=-1)
#         pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
#         pred.squeeze_(-1)
        
#         assert tokens.shape[1] == pred.shape[1] + 1
#         tokens = torch.cat((tokens[:, :1], pred), dim=1)
#         return tokens

def filling_sequence_dsr(
        model, 
        seq0,
        seq1, 
        warmup_steps=3,
        block_hw=(4, 4),
        strategy=IterativeEntfilterStrategy(topk=10),
        ):
    '''
        seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
            4095 {layout[2]} final_token.
        Attention:
        The sampling temperature are changing, temporally we hard code them here.
        The temperature in the strategy is not used.
    '''
    assert hasattr(model, 'layout')
    layout = model.layout
    assert len(seq0.shape) == 2 and len(seq1.shape) == 2 \
        and seq0.shape[0] == seq1.shape[0]
    assert len(layout) == 3
    assert seq1.shape[1] == layout[-1] - layout[-2] + 1
    assert (seq1 >= 0).all() and (seq0 >= 0).all()
    device = seq0.device
    # concat and pad sequences
    batch_size = seq0.shape[0]
    n_pad = layout[1] - seq0.shape[1]
    assert n_pad > 0, "You should truncate long input before filling."
    seq = torch.cat((
        torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype)
            .unsqueeze(0).expand(batch_size, n_pad),
        seq0, seq1), dim=1) # [b, layout[-1]+1]
    assert seq.shape[1] == layout[-1] + 1

    # build initial tokens, attention_mask, and position_ids
    tokens = seq.clone()
    attention_mask = torch.ones(layout[1], layout[1]).to(device)
    attention_mask[:layout[0], layout[0]:] = 0
    attention_mask[n_pad:, :n_pad] = 0
    attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
    position_ids = torch.cat((
        torch.zeros(n_pad, dtype=torch.long),
        torch.arange(0, layout[0] - n_pad), 
        torch.arange(513, 513 + layout[1] - layout[0]),
        torch.arange(1024, 1024+layout[2]-layout[1]))).to(device)
    log_attention_weights = torch.zeros(layout[1], layout[1], 
            device=device).type_as(next(model.parameters()))
    log_attention_weights[layout[0]:, n_pad:layout[0]] = 0.

    # prepare for interation
    unfixed = (tokens < 0) # just init an all-False tensor
    unfixed[:, -layout[-1] + layout[-2]:] = True
    
    ll, rr = block_hw
    edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
    num_steps = warmup_steps + ll - 1 + rr
    # interative refining
    
    # unfixed[..., -(layout[-1] - layout[-2]):].view(
    #     batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, :, :, -1] = False
    
    
    ret = []
    ret.append(tokens[:, layout[-2]+1:].clone())
    for step_cnt in range(1, num_steps+1):
        if step_cnt <= warmup_steps:
            logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights)
            real_temp = 1.
            new_tokens = strategy.forward(logits, tokens, real_temp)
            tokens[unfixed] = new_tokens[unfixed]
        else:
            logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask, log_attention_weights=log_attention_weights)
            real_temp = 1.
            new_tokens = strategy.forward(
                logits, tokens, real_temp,
                entfilter=1.3,
                filter_topk=5,
                temperature2=0.6
            )
            # tokens[unfixed] = new_tokens[unfixed]
            # fixed tokens (update unfixed)
            unfixed2 = (tokens > 10000000)
            for x in range(min(ll, step_cnt - warmup_steps)):
                y = step_cnt - warmup_steps - x - 1
                if y < rr:
                    unfixed[..., -(layout[-1] - layout[-2]):].view(
                        batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
                    unfixed2[..., -(layout[-1] - layout[-2]):].view(
                        batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = True
            tokens[unfixed2] = new_tokens[unfixed2]
                
        ret.append(tokens[:, layout[-2]+1:].clone())

    return ret
